{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nifti Read Example\n",
    "\n",
    "The purpose of this notebook is to illustrate reading Nifti files and iterating over patches of the volumes loaded from them.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/MONAI/blob/master/examples/notebooks/nifti_read_example.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Note: you may need to restart the kernel to use updated packages.\n"
    }
   ],
   "source": [
    "%pip install -qU \"monai[nibabel]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "MONAI version: 0.2.0\nPython version: 3.7.5 (default, Nov  7 2019, 10:50:52)  [GCC 8.3.0]\nNumpy version: 1.19.1\nPytorch version: 1.6.0\n\nOptional dependencies:\nPytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.\nNibabel version: 3.1.1\nscikit-image version: NOT INSTALLED or UNKNOWN VERSION.\nPillow version: NOT INSTALLED or UNKNOWN VERSION.\nTensorboard version: NOT INSTALLED or UNKNOWN VERSION.\n\nFor details about installing the optional dependencies, please visit:\n    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n\n"
    }
   ],
   "source": [
    "# Copyright 2020 MONAI Consortium\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "\n",
    "import glob\n",
    "import os\n",
    "import shutil\n",
    "import tempfile\n",
    "\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from monai.config import print_config\n",
    "from monai.data import ArrayDataset, GridPatchDataset, create_test_image_3d\n",
    "from monai.transforms import (\n",
    "    AddChannel,\n",
    "    Compose,\n",
    "    LoadNifti,\n",
    "    RandSpatialCrop,\n",
    "    ScaleIntensity,\n",
    "    ToTensor,\n",
    ")\n",
    "from monai.utils import first\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If it is not specified, a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "/home/bengorman/notebooks/\n"
    }
   ],
   "source": [
    "# Use the directory named by MONAI_DATA_DIRECTORY when it is set; otherwise\n",
    "# create a fresh temporary directory. `tempfile.mkdtemp` has to be *called* so\n",
    "# that `root_dir` holds a path string (usable with `os.path.join` and\n",
    "# `shutil.rmtree` below) rather than the function object itself.\n",
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a number of test Nifti files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write five synthetic image/segmentation pairs into `root_dir` as\n",
    "# im<k>.nii.gz and seg<k>.nii.gz, each wrapped with a 4x4 identity matrix.\n",
    "for idx in range(5):\n",
    "    image_vol, label_vol = create_test_image_3d(128, 128, 128)\n",
    "\n",
    "    image_path = os.path.join(root_dir, f\"im{idx}.nii.gz\")\n",
    "    label_path = os.path.join(root_dir, f\"seg{idx}.nii.gz\")\n",
    "\n",
    "    nib.save(nib.Nifti1Image(image_vol, np.eye(4)), image_path)\n",
    "    nib.save(nib.Nifti1Image(label_vol, np.eye(4)), label_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a data loader which yields uniform random patches from loaded Nifti files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "torch.Size([5, 1, 64, 64, 64]) torch.Size([5, 1, 64, 64, 64])\n"
    }
   ],
   "source": [
    "# Gather the Nifti files from `root_dir`; both lists are sorted so that\n",
    "# im<k> and seg<k> sit at the same position.\n",
    "image_pattern = os.path.join(root_dir, \"im*.nii.gz\")\n",
    "seg_pattern = os.path.join(root_dir, \"seg*.nii.gz\")\n",
    "images = sorted(glob.glob(image_pattern))\n",
    "segs = sorted(glob.glob(seg_pattern))\n",
    "\n",
    "# Both pipelines load a volume, add a channel, take a fixed-size random crop and\n",
    "# convert the result to a tensor; only the image pipeline applies ScaleIntensity.\n",
    "patch_size = (64, 64, 64)\n",
    "\n",
    "image_steps = [\n",
    "    LoadNifti(image_only=True),\n",
    "    ScaleIntensity(),\n",
    "    AddChannel(),\n",
    "    RandSpatialCrop(patch_size, random_size=False),\n",
    "    ToTensor(),\n",
    "]\n",
    "imtrans = Compose(image_steps)\n",
    "\n",
    "seg_steps = [\n",
    "    LoadNifti(image_only=True),\n",
    "    AddChannel(),\n",
    "    RandSpatialCrop(patch_size, random_size=False),\n",
    "    ToTensor(),\n",
    "]\n",
    "segtrans = Compose(seg_steps)\n",
    "\n",
    "ds = ArrayDataset(images, imtrans, segs, segtrans)\n",
    "\n",
    "# Pin host memory only when a CUDA device is available.\n",
    "use_pinned_memory = torch.cuda.is_available()\n",
    "loader = torch.utils.data.DataLoader(ds, batch_size=10, num_workers=2, pin_memory=use_pinned_memory)\n",
    "\n",
    "# Pull a single batch and report the image and segmentation batch shapes.\n",
    "im, seg = first(loader)\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, create a data loader which yields patches in regular grid order from loaded images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n"
    }
   ],
   "source": [
    "# No random cropping in these pipelines: whole volumes are loaded and\n",
    "# GridPatchDataset is given the (64, 64, 64) patch size instead.\n",
    "imtrans = Compose(\n",
    "    [\n",
    "        LoadNifti(image_only=True),\n",
    "        ScaleIntensity(),\n",
    "        AddChannel(),\n",
    "        ToTensor(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "segtrans = Compose(\n",
    "    [\n",
    "        LoadNifti(image_only=True),\n",
    "        AddChannel(),\n",
    "        ToTensor(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "volume_ds = ArrayDataset(images, imtrans, segs, segtrans)\n",
    "ds = GridPatchDataset(volume_ds, (64, 64, 64))\n",
    "\n",
    "# Pin host memory only when a CUDA device is available.\n",
    "use_pinned_memory = torch.cuda.is_available()\n",
    "loader = torch.utils.data.DataLoader(ds, batch_size=10, num_workers=2, pin_memory=use_pinned_memory)\n",
    "\n",
    "# Pull a single batch and report the image and segmentation batch shapes.\n",
    "im, seg = first(loader)\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup data directory\n",
    "\n",
    "Remove the directory if a temporary one was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `directory` is None only when MONAI_DATA_DIRECTORY was not set; a directory\n",
    "# supplied through that variable is left in place.\n",
    "used_temp_dir = directory is None\n",
    "if used_temp_dir:\n",
    "    shutil.rmtree(root_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
